{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import libs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/xiangjx/anaconda3/envs/musk/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from timm.models import create_model\n",
    "from musk import utils, modeling\n",
    "from PIL import Image\n",
    "from transformers import XLMRobertaTokenizer\n",
    "from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD\n",
    "import torchvision\n",
    "from huggingface_hub import login\n",
    "login(<HF Token>)\n",
    "device = torch.device(\"cuda:2\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extract Image Embeddings\n",
    "\n",
    "- Set `ms_aug = True` for:  \n",
    "  - Linear probe classification  \n",
    "  - Multiple Instance Learning  \n",
    "\n",
    "- Set `ms_aug = False` for:  \n",
    "  - Zero-shot tasks (e.g., image-image retrieval and image-text retrieval)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load ckpt from hf_hub:xiangjx/musk\n",
      "torch.Size([1, 2048])\n"
     ]
    }
   ],
   "source": [
    "# >>>>>>>>>>>> load model >>>>>>>>>>>> #\n",
    "model_config = \"musk_large_patch16_384\"\n",
    "model = create_model(model_config).eval()\n",
    "utils.load_model_and_may_interpolate(\"hf_hub:xiangjx/musk\", model, 'model|module', '')\n",
    "model.to(device, dtype=torch.float16)\n",
    "model.eval()\n",
    "# <<<<<<<<<<<< load model <<<<<<<<<<<< #\n",
    "\n",
    "# >>>>>>>>>>>> process image >>>>>>>>>>> #\n",
    "# load an image and process it\n",
    "img_size = 384\n",
    "transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.Resize(img_size, interpolation=3, antialias=True),\n",
    "    torchvision.transforms.CenterCrop((img_size, img_size)),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)\n",
    "])\n",
    "\n",
    "img = Image.open('assets/lungaca1014.jpeg').convert(\"RGB\")  # input image\n",
    "img_tensor = transform(img).unsqueeze(0)\n",
    "with torch.inference_mode():\n",
    "    image_embeddings = model(\n",
    "        image=img_tensor.to(device, dtype=torch.float16),\n",
    "        with_head=False, # We only use the retrieval head for image-text retrieval tasks.\n",
    "        out_norm=True,\n",
    "        ms_aug=True  # by default it is False, `image_embeddings` will be 1024-dim; if True, it will be 2048-dim.\n",
    "        )[0]  # return (vision_cls, text_cls)\n",
    "\n",
    "print(image_embeddings.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Text-Image Retrieval\n",
    "### Zero-shot image classification follows the same configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Load ckpt from hf_hub:xiangjx/musk\n",
      "tensor([[0.3782, 0.3247, 0.2969]], device='cuda:2', dtype=torch.float16)\n"
     ]
    }
   ],
   "source": [
    "# >>>>>>>>>>>> load model >>>>>>>>>>>> #\n",
    "model_config = \"musk_large_patch16_384\"\n",
    "model = create_model(model_config).eval()\n",
    "utils.load_model_and_may_interpolate(\"hf_hub:xiangjx/musk\", model, 'model|module', '')\n",
    "model.to(device, dtype=torch.float16)\n",
    "model.eval()\n",
    "# <<<<<<<<<<<< load model <<<<<<<<<<<< #\n",
    "\n",
    "# >>>>>>>>>>>> process image >>>>>>>>>>> #\n",
    "# load an image and process it\n",
    "img_size = 384\n",
    "transform = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.Resize(img_size, interpolation=3, antialias=True),\n",
    "    torchvision.transforms.CenterCrop((img_size, img_size)),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)\n",
    "])\n",
    "\n",
    "img = Image.open('assets/lungaca1014.jpeg').convert(\"RGB\")  # input image\n",
    "img_tensor = transform(img).unsqueeze(0)\n",
    "with torch.inference_mode():\n",
    "    image_embeddings = model(\n",
    "        image=img_tensor.to(device, dtype=torch.float16),\n",
    "        with_head=True,  # We utilize this head for zero-shot tasks (image-text retrieval and zero-shot image classification).\n",
    "        out_norm=True    # Ensure that the embedding is normalized.\n",
    "        )[0]\n",
    "# <<<<<<<<<<< process image <<<<<<<<<<< #\n",
    "\n",
    "# >>>>>>>>>>> process language >>>>>>>>> #\n",
    "# load tokenzier for language input\n",
    "tokenizer = XLMRobertaTokenizer(\"./musk/models/tokenizer.spm\")\n",
    "labels = [\"lung adenocarcinoma\",\n",
    "            \"benign lung tissue\",\n",
    "            \"lung squamous cell carcinoma\"]\n",
    "\n",
    "texts = ['histopathology image of ' + item for item in labels]\n",
    "text_ids = []\n",
    "paddings = []\n",
    "for txt in texts:\n",
    "    txt_ids, pad = utils.xlm_tokenizer(txt, tokenizer, max_len=100)\n",
    "    text_ids.append(torch.tensor(txt_ids).unsqueeze(0))\n",
    "    paddings.append(torch.tensor(pad).unsqueeze(0))\n",
    "\n",
    "text_ids = torch.cat(text_ids)\n",
    "paddings = torch.cat(paddings)\n",
    "with torch.inference_mode():\n",
    "    text_embeddings = model(\n",
    "        text_description=text_ids.to(device),\n",
    "        padding_mask=paddings.to(device),\n",
    "        with_head=True, # We utilize this head for zero-shot tasks (image-text retrieval and zero-shot image classification).\n",
    "        out_norm=True   # Ensure that the embedding is normalized.\n",
    "    )[1]\n",
    "# <<<<<<<<<<<< process language <<<<<<<<<<< #\n",
    "\n",
    "# >>>>>>>>>>>>> calculate similarity >>>>>>> #\n",
    "with torch.inference_mode():\n",
    "    # expected prob:[0.3782, 0.3247, 0.2969]  --> lung adenocarcinoma\n",
    "    sim = model.logit_scale * image_embeddings @ text_embeddings.T\n",
    "    prob = sim.softmax(dim=-1)\n",
    "    print(prob)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "musk",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
